/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "lib/reduce/reduce_tiling.h"

#include <string>
#include <cstdint>
#include <algorithm>

#include "graph/tensor.h"
#include "impl/host_log.h"
namespace AscendC {
namespace {
// Alignment unit in bytes: the sizes computed in this file are rounded up to a multiple of this value.
constexpr uint32_t ONE_BLK_SIZE = 32;
// Byte size of one repeat; divided by the element size to obtain the element count per repeat.
constexpr uint32_t ONE_REPEAT_BYTE_SIZE = 256;
// Element size in bytes that GetTypeSize reports for ge::DT_FLOAT16.
constexpr uint32_t HALF_TYPE_SIZE = 2;
// Element size in bytes that GetTypeSize reports for ge::DT_FLOAT.
constexpr uint32_t FLOAT_TYPE_SIZE = 4;
// The only source shape rank accepted by CheckParams.
constexpr uint32_t ALLOWED_SHAPE_DIM = 2;
// Last-axis length up to which the Sum/Mean AR path needs no temporary buffer
// (numerically ONE_REPEAT_BYTE_SIZE / FLOAT_TYPE_SIZE).
constexpr uint32_t B32_ELEM_NUM_PER_REPEAT = 64;

// Maps a data type to the element size used by the size formulas in this file:
// FLOAT_TYPE_SIZE for DT_FLOAT, HALF_TYPE_SIZE for DT_FLOAT16 and 1 for every other type.
uint32_t GetTypeSize(const ge::DataType dataType)
{
    switch (dataType) {
        case ge::DT_FLOAT:
            return FLOAT_TYPE_SIZE;
        case ge::DT_FLOAT16:
            return HALF_TYPE_SIZE;
        default:
            // Every remaining type is treated as a one-byte element.
            return 1;
    }
}
// Returns the largest power of two that does not exceed n.
// Inputs 0 and 1 both yield 1, because the result starts at 1 and is doubled once per halving of n.
uint32_t FindK(uint32_t n)
{
    uint32_t powerOfTwo = 1U;
    for (uint32_t remaining = n; remaining > 1U; remaining >>= 1U) {
        powerOfTwo <<= 1U;
    }
    return powerOfTwo;
}

// Validates the arguments shared by all reduce tmp-size queries and reports each violation through
// ASCENDC_HOST_ASSERT, tagged with apiName/funcName:
//   - shapeDims must contain exactly ALLOWED_SHAPE_DIM entries;
//   - isSrcInnerPad must be true;
//   - pattern must be ReducePattern::AR or ReducePattern::RA;
//   - first and last must both be greater than 0.
// The failure action handed to the macro is `return`, so checking stops at the first violation.
// The function returns void: callers are not told whether validation passed.
// The container and strings are only read, so they are taken by const reference to avoid copying
// them on every call.
inline void CheckParams(const std::vector<int64_t> &shapeDims, bool isSrcInnerPad, ReducePattern pattern,
    uint32_t first, uint32_t last, const std::string &apiName, const std::string &funcName)
{
    ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return,
        "[%s][%s] srcShape dims must be 2.", apiName.c_str(), funcName.c_str());
    ASCENDC_HOST_ASSERT(isSrcInnerPad, return,
        "[%s][%s] isSrcInnerPad must be true on this platform.", apiName.c_str(), funcName.c_str());
    ASCENDC_HOST_ASSERT(pattern == ReducePattern::AR || pattern == ReducePattern::RA,
        return,
        "[%s][%s] Currently only support AR and RA pattern.", apiName.c_str(), funcName.c_str());
    ASCENDC_HOST_ASSERT(first > 0 && last > 0, return,
        "[%s][%s] both first and last axis must be greater than 0.", apiName.c_str(), funcName.c_str());
}
} // namespace

/**
 * @brief Computes the temporary buffer size for the reduce variants that share this formula
 *        (called in this file by the ReduceMax, ReduceMin and float ReduceAny/ReduceAll queries).
 *
 * @param srcShape       Source shape; must have exactly two dims, read as (first, last).
 * @param dataType       Element type; its size is taken from GetTypeSize.
 * @param pattern        ReducePattern::AR selects the AR formulas; any other value falls through to
 *                       the formula at the end of the function (CheckParams only accepts AR and RA).
 * @param isSrcInnerPad  Forwarded to CheckParams for validation.
 * @param isReuseSource  When true both outputs are set to 0.
 * @param maxValue       [out] Computed size; always equal to minValue.
 * @param minValue       [out] Computed size.
 * @param isBinaryAdd    Selects the power-of-two based formula in the AR branch.
 * @param apiName        API name used in diagnostics.
 * @param funcName       Function name used in diagnostics.
 *
 * @note If srcShape does not have exactly two dims the error is reported and the function returns
 *       without writing maxValue/minValue. The remaining checks performed by CheckParams are only
 *       reported; CheckParams returns void, so the computation below still runs.
 * @note Dims are narrowed from int64_t to uint32_t and all arithmetic is 32-bit; no overflow
 *       checking is performed.
 */
void GetReduceCommonMaxMinTmpSize(const ge::Shape &srcShape,
                                const ge::DataType dataType,
                                ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
                                uint32_t &maxValue, uint32_t &minValue, bool isBinaryAdd,
                                std::string apiName, std::string funcName)
{
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // The rank must be verified before shapeDims[0] / shapeDims[1] are read; indexing a shorter
    // vector would be an out-of-bounds access.
    ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return,
        "[%s][%s] srcShape dims must be 2.", apiName.c_str(), funcName.c_str());

    const uint32_t first = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t last = static_cast<uint32_t>(shapeDims[1]);
    CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName);
    if (isReuseSource) {
        maxValue = minValue = 0U;
        return;
    }

    const uint32_t typeSize = GetTypeSize(dataType);
    if (pattern == ReducePattern::AR) {
        if (isBinaryAdd) {
            // Largest power of two strictly below `last` (or 1 when last <= 1), as a byte count
            // rounded up to a whole block.
            uint32_t k = FindK(last);
            if (k == last && k > 1U) {
                k >>= 1U;
            }
            maxValue = minValue = (k * typeSize + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE;
            return;
        }
        const uint32_t elePerBlk = ONE_BLK_SIZE / typeSize;
        const uint32_t elePerRep = ONE_REPEAT_BYTE_SIZE / typeSize;
        if (last <= elePerBlk) {
            // A row that fits into one block needs no temporary space.
            maxValue = minValue = 0U;
        } else if (last < elePerRep) {
            // One block per row.
            maxValue = minValue = first * elePerBlk * typeSize;
        } else {
            // One repeat per row.
            maxValue = minValue = first * elePerRep * typeSize;
        }
        return;
    }

    // Non-AR path: k block-aligned rows, k being the largest power of two strictly below `first`
    // (or 1 when first <= 1).
    uint32_t k = FindK(first);
    if (k == first && first > 1U) {
        k >>= 1U;
    }
    maxValue = minValue = k * ((last * typeSize + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE);
}

/**
 * @brief Computes the temporary buffer size shared by the ReduceSum and ReduceMean queries.
 *
 * @param srcShape       Source shape; must have exactly two dims, read as (first, last).
 * @param dataType       Not read by this function; all sizes use FLOAT_TYPE_SIZE.
 * @param pattern        ReducePattern::AR selects the AR formula, anything else the other branch
 *                       (CheckParams only accepts AR and RA).
 * @param isSrcInnerPad  Forwarded to CheckParams for validation.
 * @param isReuseSource  When true both outputs are set to 0.
 * @param maxValue       [out] Computed size; always equal to minValue.
 * @param minValue       [out] Computed size.
 * @param apiName        API name used in diagnostics.
 * @param funcName       Function name used in diagnostics.
 *
 * @note If srcShape does not have exactly two dims the error is reported and the function returns
 *       without writing maxValue/minValue. The remaining checks performed by CheckParams are only
 *       reported; CheckParams returns void, so the computation below still runs.
 * @note Dims are narrowed from int64_t to uint32_t and all arithmetic is 32-bit; no overflow
 *       checking is performed.
 */
inline void GetReduceSumMeanCommonTmpSize(const ge::Shape &srcShape,
                               const ge::DataType dataType,
                               ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
                               uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName)
{
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // The rank must be verified before shapeDims[0] / shapeDims[1] are read; indexing a shorter
    // vector would be an out-of-bounds access.
    ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return,
        "[%s][%s] srcShape dims must be 2.", apiName.c_str(), funcName.c_str());

    const uint32_t first = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t last = static_cast<uint32_t>(shapeDims[1]);
    CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName);
    if (isReuseSource) {
        maxValue = minValue = 0U;
        return;
    }

    if (pattern == ReducePattern::AR) {
        if (last <= B32_ELEM_NUM_PER_REPEAT) {
            // Short rows need no temporary space.
            maxValue = minValue = 0U;
            return;
        }
        uint32_t k = FindK(last);
        // NOTE(review): the halving is guarded by `first > 1U` although k is compared with `last`;
        // the non-AR branch below guards with the same axis it compares. Kept as in the original -
        // confirm against the kernel-side buffer usage before changing.
        if (k == last && first > 1U) {
            k >>= 1U;
        }
        maxValue = minValue = (first * k) * FLOAT_TYPE_SIZE;
        return;
    }

    // Non-AR path: k rows of block-padded length, k being the largest power of two strictly below
    // `first` (or 1 when first <= 1).
    const uint32_t elePerBlk = ONE_BLK_SIZE / FLOAT_TYPE_SIZE;
    const uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk;
    uint32_t k = FindK(first);
    if (first == k && first > 1U) {
        k >>= 1U;
    }
    maxValue = minValue = (k * padLast) * FLOAT_TYPE_SIZE;
}

/**
 * @brief Computes the temporary buffer size used by the uint8 paths of the ReduceAny and ReduceAll
 *        queries.
 *
 * @param srcShape       Source shape; must have exactly two dims, read as (first, last).
 * @param dataType       Not read by this function; the formulas assume one-byte elements.
 * @param pattern        ReducePattern::AR selects the AR formula, anything else the other branch
 *                       (CheckParams only accepts AR and RA).
 * @param isSrcInnerPad  Forwarded to CheckParams for validation.
 * @param isReuseSource  Only honoured on the non-AR path, where it sets both outputs to 0; the AR
 *                       path computes the same size regardless of this flag.
 * @param maxValue       [out] Computed size; always equal to minValue.
 * @param minValue       [out] Computed size.
 * @param apiName        API name used in diagnostics.
 * @param funcName       Function name used in diagnostics.
 *
 * @note If srcShape does not have exactly two dims the error is reported and the function returns
 *       without writing maxValue/minValue. The remaining checks performed by CheckParams are only
 *       reported; CheckParams returns void, so the computation below still runs.
 * @note Dims are narrowed from int64_t to uint32_t; no overflow checking is performed.
 */
inline void GetReduceAnyAllCommonTmpSize(const ge::Shape &srcShape,
                                const ge::DataType dataType,
                                ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
                                uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName)
{
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // The rank must be verified before shapeDims[0] / shapeDims[1] are read; indexing a shorter
    // vector would be an out-of-bounds access.
    ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return,
        "[%s][%s] srcShape dims must be 2.", apiName.c_str(), funcName.c_str());

    const uint32_t first = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t last = static_cast<uint32_t>(shapeDims[1]);
    CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName);

    if (pattern == ReducePattern::AR) {
        // One block-padded row counted with two bytes per element, plus one block per row.
        const uint32_t elePerBlk = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(uint8_t));
        const uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk;
        minValue = maxValue = static_cast<uint32_t>(padLast * sizeof(uint16_t)) + (first * elePerBlk);
        return;
    }

    if (isReuseSource) {
        maxValue = minValue = 0U;
        return;
    }
    // k block-aligned rows, k being the largest power of two strictly below `first`
    // (or 1 when first <= 1).
    uint32_t k = FindK(first);
    if (k == first && first > 1U) {
        k >>= 1U;
    }
    maxValue = minValue = k * ((last + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE);
}

/**
 * @brief Computes the temporary buffer size for ReduceProd.
 *
 * @param srcShape       Source shape; must have exactly two dims, read as (first, last).
 * @param dataType       Must be ge::DT_FLOAT.
 * @param pattern        ReducePattern::AR selects the AR formula, anything else the formula at the
 *                       end of the function (CheckParams only accepts AR and RA).
 * @param isSrcInnerPad  Forwarded to CheckParams for validation.
 * @param isReuseSource  When true the result is ONE_REPEAT_BYTE_SIZE for AR and 0 otherwise.
 * @param maxValue       [out] Computed size; always equal to minValue.
 * @param minValue       [out] Computed size.
 *
 * @note If dataType is not DT_FLOAT, or srcShape does not have exactly two dims, the error is
 *       reported and the function returns without writing maxValue/minValue. The remaining checks
 *       performed by CheckParams are only reported; CheckParams returns void, so the computation
 *       below still runs.
 * @note Dims are narrowed from int64_t to uint32_t and all arithmetic is 32-bit; no overflow
 *       checking is performed.
 */
void GetReduceProdMaxMinTmpSize(const ge::Shape &srcShape,
                                const ge::DataType dataType,
                                ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
                                uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return,
        "[ReduceProd][GetReduceProdMaxMinTmpSize] it only supports float type on this platform.");
    const std::vector<int64_t> shapeDims = srcShape.GetDims();
    // The rank must be verified before shapeDims[0] / shapeDims[1] are read; indexing a shorter
    // vector would be an out-of-bounds access.
    ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return,
        "[ReduceProd][GetReduceProdMaxMinTmpSize] srcShape dims must be 2.");

    const uint32_t first = static_cast<uint32_t>(shapeDims[0]);
    const uint32_t last = static_cast<uint32_t>(shapeDims[1]);
    CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, "ReduceProd", "GetReduceProdMaxMinTmpSize");
    if (isReuseSource) {
        minValue = pattern == ReducePattern::AR ? ONE_REPEAT_BYTE_SIZE : 0U;
        maxValue = minValue;
        return;
    }

    if (pattern == ReducePattern::AR) {
        const uint32_t elePerBlk = ONE_BLK_SIZE / FLOAT_TYPE_SIZE;
        uint32_t k = FindK(last);
        // NOTE(review): the halving is guarded by `first > 1U` although k is compared with `last`;
        // the non-AR path below guards with the same axis it compares. Kept as in the original -
        // confirm against the kernel-side buffer usage before changing.
        if (k == last && first > 1U) {
            k >>= 1U;
        }
        // At least one full block of elements is reserved even when k is smaller than a block.
        const uint32_t blkAlignK = elePerBlk > k ? elePerBlk : k;
        maxValue = minValue = (blkAlignK + first * elePerBlk) * FLOAT_TYPE_SIZE + ONE_REPEAT_BYTE_SIZE;
        return;
    }

    // Non-AR path: k block-aligned rows, k being the largest power of two strictly below `first`
    // (or 1 when first <= 1).
    uint32_t k = FindK(first);
    if (k == first && first > 1U) {
        k >>= 1U;
    }
    maxValue = minValue = k * ((last * GetTypeSize(dataType) + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE);
}

/**
 * @brief Tmp-size query for ReduceMax.
 *
 * Accepts ge::DT_FLOAT and ge::DT_FLOAT16 only; any other type triggers the assertion, whose
 * failure action is `return`. Valid requests are delegated to GetReduceCommonMaxMinTmpSize with
 * binary-add disabled, which fills maxValue and minValue.
 */
void GetReduceMaxMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                               bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_FLOAT16, return,
        "[ReduceMax][GetReduceMaxMaxMinTmpSize] it only supports float and half type on this platform.");

    constexpr bool useBinaryAdd = false;
    constexpr const char *apiTag = "ReduceMax";
    constexpr const char *funcTag = "GetReduceMaxMaxMinTmpSize";
    GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        useBinaryAdd, apiTag, funcTag);
}

/**
 * @brief Tmp-size query for ReduceMin.
 *
 * Accepts ge::DT_FLOAT and ge::DT_FLOAT16 only; any other type triggers the assertion, whose
 * failure action is `return`. Valid requests are delegated to GetReduceCommonMaxMinTmpSize with
 * binary-add disabled, which fills maxValue and minValue.
 */
void GetReduceMinMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                               bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_FLOAT16, return,
        "[ReduceMin][GetReduceMinMaxMinTmpSize] it only supports float and half type on this platform.");

    constexpr bool useBinaryAdd = false;
    constexpr const char *apiTag = "ReduceMin";
    constexpr const char *funcTag = "GetReduceMinMaxMinTmpSize";
    GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        useBinaryAdd, apiTag, funcTag);
}

/**
 * @brief Tmp-size query for ReduceAny.
 *
 * Accepts ge::DT_FLOAT and ge::DT_UINT8 only; any other type triggers the assertion, whose failure
 * action is `return`. DT_UINT8 requests go to GetReduceAnyAllCommonTmpSize; DT_FLOAT requests go to
 * GetReduceCommonMaxMinTmpSize with binary-add disabled. The callee fills maxValue and minValue.
 */
void GetReduceAnyMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                               bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8, return,
        "[ReduceAny][GetReduceAnyMaxMinTmpSize] it only supports float and uint8_t type on this platform.");

    constexpr const char *apiTag = "ReduceAny";
    constexpr const char *funcTag = "GetReduceAnyMaxMinTmpSize";
    if (dataType == ge::DT_UINT8) {
        GetReduceAnyAllCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
            apiTag, funcTag);
        return;
    }

    constexpr bool useBinaryAdd = false;
    GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        useBinaryAdd, apiTag, funcTag);
}

/**
 * @brief Tmp-size query for ReduceAll.
 *
 * Accepts ge::DT_FLOAT and ge::DT_UINT8 only; any other type triggers the assertion, whose failure
 * action is `return`. DT_UINT8 requests go to GetReduceAnyAllCommonTmpSize; DT_FLOAT requests go to
 * GetReduceCommonMaxMinTmpSize with binary-add disabled. The callee fills maxValue and minValue.
 */
void GetReduceAllMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                               bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT((dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8), return,
        "[ReduceAll][GetReduceAllMaxMinTmpSize] it only supports float and uint8 type on this platform.");

    constexpr const char *apiTag = "ReduceAll";
    constexpr const char *funcTag = "GetReduceAllMaxMinTmpSize";
    if (dataType == ge::DT_UINT8) {
        GetReduceAnyAllCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
            apiTag, funcTag);
        return;
    }

    constexpr bool useBinaryAdd = false;
    GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        useBinaryAdd, apiTag, funcTag);
}

/**
 * @brief Tmp-size query for ReduceSum.
 *
 * Accepts ge::DT_FLOAT only; any other type triggers the assertion, whose failure action is
 * `return`. Valid requests are delegated to GetReduceSumMeanCommonTmpSize, which fills maxValue
 * and minValue.
 */
void GetReduceSumMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                               bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return,
        "[ReduceSum][GetReduceSumMaxMinTmpSize] it only supports float type on this platform.");

    constexpr const char *apiTag = "ReduceSum";
    constexpr const char *funcTag = "GetReduceSumMaxMinTmpSize";
    GetReduceSumMeanCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        apiTag, funcTag);
}

/**
 * @brief Tmp-size query for ReduceMean.
 *
 * Accepts ge::DT_FLOAT only; any other type triggers the assertion, whose failure action is
 * `return`. Valid requests are delegated to GetReduceSumMeanCommonTmpSize, which fills maxValue
 * and minValue.
 */
void GetReduceMeanMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern,
                                bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue)
{
    ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return,
        "[ReduceMean][GetReduceMeanMaxMinTmpSize] it only supports float type on this platform.");

    constexpr const char *apiTag = "ReduceMean";
    constexpr const char *funcTag = "GetReduceMeanMaxMinTmpSize";
    GetReduceSumMeanCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
        apiTag, funcTag);
}
}  // namespace AscendC
